#!/usr/bin/python
#
# Copyright (c) 2015 Juniper Networks, Inc. All rights reserved.
#

import sys
import argparse
import requests
from requests.exceptions import ConnectionError
import json
import os
import errno
import datetime
import re
import time
import socket

# Directory holding one JSON file per port (file name is the port UUID).
PORT_PATH = "/var/lib/contrail/ports/"
# TCP port on localhost to which the port add/delete/enable/disable HTTP
# requests are sent.
PORT_IPC_VROUTER_AGENT_PORT = 9091
# Upper bound, in seconds, used when polling for the netlink socket path and
# for the vhostuser socket file.
TIMEOUT_SECS = 15
# NOTE(review): not referenced in the visible part of this script.
DPDK_NETLINK_TCP_PORT = 20914
# Upper bound, in seconds, used when polling for the agent TCP port to accept
# connections.
PORT_OPEN_TIMEOUT_SECS = 120
# Filesystem path whose existence is polled before adding a VhostUser port.
DPDK_NETLINK_SOCK_PATH = "/var/run/vrouter/dpdk_netlink"
# Key set to 1 in (or removed from) the persisted data of an ESXi port when
# the port is enabled (or disabled).
LINK_STATE = "link-state"
# Value of the "type" field that identifies an ESXi port.
ESXI_PORT_TYPE = 2


class VrouterPortControl(object):
    """Add, delete, enable or disable a port through the local agent.

    Constructing an instance parses the command line, optionally persists
    the port description as a JSON file under PORT_PATH, sends the matching
    HTTP request to the agent listening on localhost and then terminates
    the process with sys.exit().

    Exit codes:
        0 - success (or the agent was unreachable for a non-VhostUser port)
        1 - invalid arguments, missing vhostuser socket argument, agent
            rejected a delete/enable/disable request, or (add only) the
            port file could not be written and the agent was unreachable
        2 - agent TCP port did not open in time (VhostUser add)
        3 - netlink socket path did not appear in time (VhostUser add)
        4 - agent rejected the add request
        5 - agent unreachable while adding a VhostUser port
        6 - vhostuser socket file did not appear in time
    """

    # HTTP headers sent with every request to the agent.
    _HEADERS = {'content-type': 'application/json'}

    # String-valued options from which surrounding double quotes are removed
    # when the command line contains quoted values.
    _QUOTED_OPTIONS = (
        'oper', 'uuid', 'instance_uuid', 'ip_address', 'ipv6_address',
        'vn_uuid', 'vm_name', 'vm_project_uuid', 'mac', 'tap_name',
        'port_type', 'rx_vlan_id', 'tx_vlan_id', 'vif_type',
        'vhostuser_socket', 'vhostuser_mode',
    )

    # Lower-cased values of --no_persist that mean "false".
    _FALSE_STRINGS = ('', '0', 'false', 'no', 'off')

    def __init__(self, args_str=None):
        """Parse args_str (or sys.argv), run the operation and exit.

        args_str: optional string holding the whole command line; when it
        is empty or None the process arguments are used instead.
        """
        self._args = None
        if not args_str:
            args_str = ' '.join(sys.argv[1:])
        self._parse_args(args_str)
        # TODO: Move this to other place as it tries to create directory every time
        self.CreateDirectoryTree(PORT_PATH)

        ret_code = 0
        oper = self._args.oper
        if oper == "add":
            ret_code = self._add_port()
        elif oper == "delete":
            ret_code = self._delete_port()
        elif oper == "enable":
            ret_code = self._set_link_state(True)
        elif oper == "disable":
            ret_code = self._set_link_state(False)

        sys.exit(ret_code)
    # end __init__

    @staticmethod
    def _str_to_bool(value):
        """Convert a command-line string to a bool.

        bool() cannot be used as an argparse type because every non-empty
        string, including "False", is truthy. Surrounding double quotes are
        ignored; any value not listed in _FALSE_STRINGS is treated as true.
        """
        text = value.strip()
        if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
            text = text[1:-1]
        return text.strip().lower() not in VrouterPortControl._FALSE_STRINGS
    # end _str_to_bool

    @staticmethod
    def _format_project_id(project_uuid):
        """Return a 32 character uuid in 8-4-4-4-12 dashed form.

        Any other input (None, empty, wrong length) yields an empty string.
        """
        u = project_uuid
        if u and len(u) == 32:
            return '-'.join((u[:8], u[8:12], u[12:16], u[16:20], u[20:]))
        return ""
    # end _format_project_id

    def _agent_url(self, path):
        """Return the agent URL on localhost for the given request path."""
        return "http://localhost:" + str(PORT_IPC_VROUTER_AGENT_PORT) + path
    # end _agent_url

    def _send_request(self, method, path):
        """Send a body-less request to the agent and return an exit code.

        method: one of the requests functions (requests.delete/requests.put).
        Returns 1 when the agent answers with a status other than 200 and 0
        otherwise. An unreachable agent is deliberately not an error here.
        """
        try:
            r = method(self._agent_url(path), data=None,
                       headers=self._HEADERS)
        except ConnectionError:
            return 0
        if r.status_code != 200:
            print("Request failed: %s" % r.text)
            return 1
        return 0
    # end _send_request

    def _add_port(self):
        """Handle --oper=add and return the process exit code."""
        args = self._args
        is_vhostuser = args.vif_type == "VhostUser"
        if is_vhostuser:
            if not args.vhostuser_socket:
                sys.exit(1)
            if not self.WaitForPortOpen(PORT_IPC_VROUTER_AGENT_PORT):
                sys.exit(2)
            if not self.WaitForNetlinkSocket():
                sys.exit(3)

        port_type_values = {
            "NovaVMPort": 0,
            "NameSpacePort": 1,
            "ESXiPort": ESXI_PORT_TYPE,
        }
        port_type_value = port_type_values.get(args.port_type, 0)
        project_id = self._format_project_id(args.vm_project_uuid)

        payload = self.GetJSonDict(port_type_value, project_id)
        json_dump = json.dumps(payload)
        ret_code = 0
        if not args.no_persist:
            ret_code = self.WriteToFile(payload)

        try:
            # We post the request to agent even if WriteToFile has failed.
            # Agent will write to file if file is not already present
            r = requests.post(self._agent_url("/port"), data=json_dump,
                              headers=self._HEADERS)
        except ConnectionError:
            # In DPDK mode return failure when we are not able to connect
            # to agent; otherwise keep the result of WriteToFile.
            if is_vhostuser:
                return 5
            return ret_code

        # From here on the agent response decides the exit code.
        if r.status_code != 200:
            print("Request failed: %s" % r.text)
            return 4
        if is_vhostuser and args.vhostuser_mode == 0:
            # In DPDK mode we should wait until socket is created by
            # vrouter module
            return self.WaitForSocketFile()
        return 0
    # end _add_port

    def _delete_port(self):
        """Handle --oper=delete and return the process exit code."""
        if not self._args.no_persist:
            self.DeleteFile()
        return self._send_request(requests.delete,
                                  "/port/" + self._args.uuid)
    # end _delete_port

    def _set_link_state(self, enable):
        """Handle --oper=enable/disable and return the process exit code.

        For persisted ESXi ports the LINK_STATE key is set to 1 on enable
        and removed on disable before the agent is notified.
        """
        if not self._args.no_persist:
            port_data = self.ReadFromFile()
            if port_data is not None and self.IsESXiPort(port_data):
                if enable:
                    port_data[LINK_STATE] = 1
                    self.WriteToFile(port_data)
                elif LINK_STATE in port_data:
                    port_data.pop(LINK_STATE)
                    self.WriteToFile(port_data)
        path = "/enable-port/" if enable else "/disable-port/"
        return self._send_request(requests.put, path + self._args.uuid)
    # end _set_link_state

    def StripQuotes(self, arg):
        """Return arg without one pair of surrounding double quotes.

        Values that are not quoted (including None and the argparse
        defaults) are returned unchanged.
        """
        if arg and arg.startswith('"') and arg.endswith('"'):
            return arg[1:-1]
        return arg
    # end StripQuotes

    def IsNumber(self, s):
        """Return True when int(s) succeeds, False otherwise."""
        try:
            int(s)
            return True
        except (TypeError, ValueError):
            return False
    # end IsNumber

    def _parse_args(self, args_str):
        """Parse the command line string into self._args.

        Options are expected in --name=value form. The string is split only
        at whitespace that precedes an option, so values may contain spaces.
        Prints a message and exits with status 1 on invalid values.
        """
        strip_quotes = '"' in args_str
        if strip_quotes:
            # Split only before --name="..." so that quoted values may
            # contain whitespace and dashes.
            regex = re.compile(r'\s+(?=\-\-[_a-zA-Z0-9]+=".*?")')
        else:
            regex = re.compile(r'\s+(?=\-\-)')
        # Turn off help, so we print all options in response to -h
        conf_parser = argparse.ArgumentParser(add_help=False)
        args, remaining_argv = conf_parser.parse_known_args(
            regex.split(args_str)
        )

        # Don't suppress add_help here so it will handle -h
        parser = argparse.ArgumentParser(
            # Inherit options from config_parser
            parents=[conf_parser],
            # print script description with -h/--help
            description=__doc__,
            # Don't mess with format of description
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )

        parser.add_argument('--oper', help="Operation add/delete/enable/disable ",
                            required=True)
        parser.add_argument('--uuid', help="port UUID", required=True)
        parser.add_argument('--instance_uuid', help="instance UUID")
        parser.add_argument('--vn_uuid', help="VN UUID")
        parser.add_argument('--vm_project_uuid', help="VM UUID")
        parser.add_argument('--ip_address', help="IP Address")
        parser.add_argument('--ipv6_address', help="IPv6 Address")
        parser.add_argument('--vm_name', help="VM Name")
        parser.add_argument('--mac', help="MAC address")
        parser.add_argument('--tap_name', help="System name of interface")
        parser.add_argument("--port_type", help="Port type", default="NovaVMPort")
        parser.add_argument("--tx_vlan_id", help="Transmit VLAN ID")
        parser.add_argument("--rx_vlan_id", help="Receive VLAN ID")
        # type=bool would turn every non-empty string (even "False") into
        # True, so a dedicated converter is used instead.
        parser.add_argument("--no_persist", type=self._str_to_bool,
                            help="Dont't store port information in files",
                            default=False)
        parser.add_argument("--vif_type", help="VIF type", default="Vrouter")
        parser.add_argument("--vhostuser_socket", help="Path of vhostuser socket file")
        parser.add_argument("--vhostuser_mode",
                            help="mode of operation of vhostuser socket",
                            default="0")

        self._args = parser.parse_args(remaining_argv)

        if strip_quotes:
            for name in self._QUOTED_OPTIONS:
                setattr(self._args, name,
                        self.StripQuotes(getattr(self._args, name)))

        oper_list = ['add', 'delete', 'enable', 'disable']
        if self._args.oper not in oper_list:
            print("Invalid argument for oper %s" % (self._args.oper))
            sys.exit(1)

        if self._args.oper == "add":
            port_type_list = ['NovaVMPort', 'NameSpacePort', 'ESXiPort']
            if self._args.port_type not in port_type_list:
                print("Invalid argument for port_type %s" % (self._args.port_type))
                sys.exit(1)

            vif_type_list = ['VhostUser', 'Vrouter']
            if self._args.vif_type not in vif_type_list:
                print("Invalid argument for vif_type %s" % (self._args.vif_type))
                sys.exit(1)

            if self._args.rx_vlan_id:
                if not self.IsNumber(self._args.rx_vlan_id):
                    print("Invalid argument for rx_vlan_id %s" % (self._args.rx_vlan_id))
                    sys.exit(1)
                self._args.rx_vlan_id = int(self._args.rx_vlan_id)

            if self._args.tx_vlan_id:
                if not self.IsNumber(self._args.tx_vlan_id):
                    print("Invalid argument for tx_vlan_id %s" % (self._args.tx_vlan_id))
                    sys.exit(1)
                self._args.tx_vlan_id = int(self._args.tx_vlan_id)

            if self._args.vhostuser_mode:
                # Validate like the VLAN ids instead of dying with a
                # traceback on a non-numeric value.
                if not self.IsNumber(self._args.vhostuser_mode):
                    print("Invalid argument for vhostuser_mode %s" % (self._args.vhostuser_mode))
                    sys.exit(1)
                self._args.vhostuser_mode = int(self._args.vhostuser_mode)

    #end _parse_args

    def CreateDirectoryTree(self, path):
        """Create path and its parents; an existing directory is not an error.

        Raises OSError for any other failure, including path existing as a
        non-directory.
        """
        try:
            os.makedirs(path)
        except OSError as exc:
            if exc.errno != errno.EEXIST or not os.path.isdir(path):
                raise
    # end CreateDirectoryTree

    def GetJSonDict(self, port_type, project_id):
        """Build the port description dict sent to the agent and persisted.

        port_type: numeric port type stored under "type".
        project_id: project uuid string stored under "vm-project-id".
        """
        data = {
            "id": self._args.uuid,
            "instance-id": self._args.instance_uuid,
            "ip-address": self._args.ip_address,
            "ip6-address": self._args.ipv6_address,
            "vn-id": self._args.vn_uuid,
            "display-name": self._args.vm_name,
            "vm-project-id": project_id,
            "mac-address": self._args.mac,
            "system-name": self._args.tap_name,
            "type": port_type,
            "rx-vlan-id": self._args.rx_vlan_id,
            "tx-vlan-id": self._args.tx_vlan_id,
            "vhostuser-mode": self._args.vhostuser_mode,
            "author": __file__,
            "time": str(datetime.datetime.now())
        }
        return data
    # end GetJSonDict

    def IsESXiPort(self, port_data):
        """Return True when port_data is a dict whose "type" is ESXI_PORT_TYPE.

        Malformed persisted data (not a dict, or no "type" key) yields False
        instead of raising.
        """
        return (isinstance(port_data, dict) and
                port_data.get("type") == ESXI_PORT_TYPE)
    # end IsESXiPort

    def WriteToFile(self, data):
        """Persist data as JSON in the port file; return 0 on success, 1 on failure."""
        filename = ("%s%s" % (PORT_PATH, self._args.uuid))
        try:
            with open(filename, 'w') as outfile:
                # skipkeys must be passed by keyword: it is keyword-only on
                # Python 3, where the positional form raises TypeError.
                json.dump(data, outfile, skipkeys=True)
        except (IOError, OSError, TypeError, ValueError):
            return 1
        return 0
    # end WriteToFile

    def DeleteFile(self):
        """Remove the port file if it exists as a regular file."""
        filename = ("%s%s" % (PORT_PATH, self._args.uuid))
        if os.path.isfile(filename):
            try:
                os.remove(filename)
            except OSError as exc:
                # The file may vanish between the check and the removal.
                if exc.errno != errno.ENOENT:
                    raise
    # end DeleteFile

    def ReadFromFile(self):
        """Return the parsed JSON content of the port file, or None.

        None is returned when the file cannot be read or does not contain
        valid JSON.
        """
        filename = ("%s%s" % (PORT_PATH, self._args.uuid))
        try:
            with open(filename, 'r') as port_file:
                return json.loads(port_file.read())
        except (IOError, OSError, ValueError):
            return None
    # end ReadFromFile

    def _poll_until(self, predicate, timeout_usecs, sleep_usecs):
        """Call predicate every sleep_usecs until it is true or time is up.

        Returns True as soon as predicate() is truthy, False after
        timeout_usecs // sleep_usecs unsuccessful attempts.
        """
        for _ in range(timeout_usecs // sleep_usecs):
            if predicate():
                return True
            # sleep takes time in seconds. Convert usecs to secs.
            time.sleep(sleep_usecs / 1000000.0)
        return False
    # end _poll_until

    def WaitForNetlinkSocket(self):
        """Wait up to TIMEOUT_SECS for DPDK_NETLINK_SOCK_PATH to exist.

        Returns True when the path exists, False on timeout.
        """
        #TODO: need to check if the socket path is opened
        return self._poll_until(
            lambda: os.path.exists(DPDK_NETLINK_SOCK_PATH),
            TIMEOUT_SECS * 1000000, 1000)
    #end WaitForNetlinkSocket

    def WaitForSocketFile(self):
        """Wait up to TIMEOUT_SECS for the vhostuser socket file to exist.

        Returns exit code 0 when the file exists and 6 on timeout.
        """
        found = self._poll_until(
            lambda: os.path.exists(self._args.vhostuser_socket),
            TIMEOUT_SECS * 1000000, 1000)
        if found:
            return 0
        return 6
    # end WaitForSocketFile

    def IsPortOpen(self, port_num):
        """Return True when a TCP connection to 127.0.0.1:port_num succeeds."""
        s = socket.socket()
        try:
            s.connect(('127.0.0.1', port_num))
            s.shutdown(socket.SHUT_RDWR)
            return True
        except socket.error:
            return False
        finally:
            # Always release the descriptor, also when connect() fails.
            s.close()
    # end IsPortOpen

    def WaitForPortOpen(self, port_num):
        """Wait up to PORT_OPEN_TIMEOUT_SECS for port_num to accept connections.

        Returns True when a connection succeeds, False on timeout.
        """
        return self._poll_until(
            lambda: self.IsPortOpen(port_num),
            PORT_OPEN_TIMEOUT_SECS * 1000000, 1000000)
    # end WaitForPortOpen
# end class VrouterPortControl


def main(args_str=None):
    """Run the port control tool.

    args_str: optional command line given as a single string; when omitted
    the constructor falls back to the process arguments. The constructor
    performs the requested operation and ends the process via sys.exit().
    """
    VrouterPortControl(args_str=args_str)
# end main


# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
